package de.lmu.ifi.dbs.elki.algorithm.benchmark;

import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.QueryUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDRange;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.KNNList;
import de.lmu.ifi.dbs.elki.database.query.DatabaseQuery;
import de.lmu.ifi.dbs.elki.database.query.LinearScanQuery;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.math.MeanVariance;
import de.lmu.ifi.dbs.elki.math.random.RandomFactory;
import de.lmu.ifi.dbs.elki.result.Result;
import de.lmu.ifi.dbs.elki.utilities.DatabaseUtil;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Parameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.PatternParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;
import java.util.regex.Pattern;

/* loaded from: input_file:de/lmu/ifi/dbs/elki/algorithm/benchmark/ValidateApproximativeKNNIndex.class */
public class ValidateApproximativeKNNIndex<O> extends AbstractDistanceBasedAlgorithm<O, Result> {
    /** Class logger; assigned in the static initializer and returned by {@link #getLogger()}. */
    private static final Logging LOG;
    /** Number of neighbors requested from both the tested and the reference kNN query. */
    protected int k;
    /** Optional external source of query objects; when {@code null}, {@code run} takes queries from the database relation. */
    protected DatabaseConnection queries;
    /** Sampling parameter handed unchanged to {@code DBIDUtil.randomSample} when choosing query points. */
    protected double sampling;
    /** When {@code true}, {@code run} obtains the reference query via {@code QueryUtil.getLinearScanKNNQuery}. */
    protected boolean forcelinear;
    /** Random source handed to {@code DBIDUtil.randomSample}. */
    protected RandomFactory random;
    /** Optional label filter; when set, {@code run} only queries objects whose label contains a match. */
    protected Pattern pattern;
    /** Decompiled assertion switch: set in the static initializer from {@code desiredAssertionStatus()} and checked in {@code run}. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:de/lmu/ifi/dbs/elki/algorithm/benchmark/ValidateApproximativeKNNIndex$Parameterizer.class */
    /**
     * Parameterization class: reads the options of {@link ValidateApproximativeKNNIndex}
     * and builds a configured instance.
     *
     * @param <O> object type handled by the distance function
     */
    public static class Parameterizer<O> extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
        /** Option: number of neighbors to retrieve. */
        public static final OptionID K_ID = new OptionID("validateknn.k", "Number of neighbors to retreive for kNN benchmarking.");
        /** Option: external data source providing the query objects. */
        public static final OptionID QUERY_ID = new OptionID("validateknn.query", "Data source for the queries. If not set, the queries are taken from the database.");
        /** Option: sampling size (relative share or absolute count). */
        public static final OptionID SAMPLING_ID = new OptionID("validateknn.sampling", "Sampling size parameter. If the value is less or equal 1, it is assumed to be the relative share. Larger values will be interpreted as integer sizes. By default, all data will be used.");
        /** Option: force a linear scan as the reference query. */
        public static final OptionID FORCE_ID = new OptionID("validateknn.force-linear", "Force the use of linear scanning as reference.");
        /** Option: random generator used for sampling. */
        public static final OptionID RANDOM_ID = new OptionID("validateknn.random", "Random generator for sampling.");
        /** Option: pattern used to select query points. */
        public static final OptionID PATTERN_ID = new OptionID("validateknn.pattern", "Pattern to select query points.");

        /** Number of neighbors; stays at 10 unless the option is grabbed. */
        protected int k = 10;
        /** Query data source; stays {@code null} unless the query option is grabbed. */
        protected DatabaseConnection queries = null;
        /** Sampling parameter; stays at -1 unless the option is grabbed. */
        protected double sampling = -1.0d;
        /** Whether to force a linear scan as reference; {@code false} unless the flag is grabbed. */
        protected boolean forcelinear = false;
        /** Random generator for sampling. */
        protected RandomFactory random;
        /** Pattern selecting the query points, if any. */
        protected Pattern pattern;

        /**
         * Reads all options of this algorithm from the given parameterization.
         *
         * <p>The query data source is only requested when no pattern was given:
         * the two options sit in opposite branches of the same condition.
         *
         * @param parameterization source of the option values
         */
        @Override
        public void makeOptions(Parameterization parameterization) {
            super.makeOptions(parameterization);

            IntParameter kParam = new IntParameter(K_ID);
            if (parameterization.grab(kParam)) {
                this.k = kParam.intValue();
            }

            // Concrete parameter types (rather than Parameter<?>) so that
            // getValue() is statically typed as Pattern / RandomFactory.
            PatternParameter patternParam = new PatternParameter(PATTERN_ID);
            patternParam.setOptional(true);
            if (parameterization.grab(patternParam)) {
                this.pattern = patternParam.getValue();
            } else {
                ObjectParameter<DatabaseConnection> queryParam = new ObjectParameter<>(QUERY_ID, DatabaseConnection.class);
                queryParam.setOptional(true);
                if (parameterization.grab(queryParam)) {
                    this.queries = queryParam.instantiateClass(parameterization);
                }
            }

            DoubleParameter samplingParam = new DoubleParameter(SAMPLING_ID);
            samplingParam.setOptional(true);
            if (parameterization.grab(samplingParam)) {
                this.sampling = samplingParam.doubleValue();
            }

            Flag forceParam = new Flag(FORCE_ID);
            if (parameterization.grab(forceParam)) {
                this.forcelinear = forceParam.isTrue();
            }

            RandomParameter randomParam = new RandomParameter(RANDOM_ID, RandomFactory.DEFAULT);
            if (parameterization.grab(randomParam)) {
                this.random = randomParam.getValue();
            }
        }

        /**
         * Builds the algorithm from the values collected by {@link #makeOptions}.
         *
         * @return a new, configured algorithm instance
         */
        @Override
        public ValidateApproximativeKNNIndex<O> makeInstance() {
            return new ValidateApproximativeKNNIndex<>(this.distanceFunction, this.k, this.queries, this.sampling, this.forcelinear, this.random, this.pattern);
        }
    }

    /**
     * Creates the validation algorithm.
     *
     * @param distanceFunction distance function used for both kNN queries
     * @param i number of neighbors to request (stored as {@code k})
     * @param databaseConnection source of external query objects, or {@code null} to query with the database's own objects
     * @param d sampling parameter passed on to the random sampling of query points
     * @param z whether the reference query is forced to be a linear scan
     * @param randomFactory random source for sampling
     * @param pattern label pattern selecting the query points, or {@code null} for no filtering
     */
    public ValidateApproximativeKNNIndex(DistanceFunction<? super O> distanceFunction, int i, DatabaseConnection databaseConnection, double d, boolean z, RandomFactory randomFactory, Pattern pattern) {
        super(distanceFunction);
        // Every field is taken from its argument; no intermediate defaults are needed.
        this.k = i;
        this.queries = databaseConnection;
        this.sampling = d;
        this.forcelinear = z;
        this.random = randomFactory;
        this.pattern = pattern;
    }

    /**
     * Compares the results of an accelerated kNN query against a reference kNN
     * query and reports quality statistics to the log.
     *
     * <p>Query points are taken from {@code relation} when no external query
     * source is configured or when a label pattern is set; otherwise they are
     * loaded from the configured query source.
     *
     * @param database database providing the distance and kNN queries
     * @param relation relation the queries are run against
     * @return always {@code null}; results are only written to the statistics log
     * @throws AbortException if no accelerated kNN query is available, or if the
     *         query source contains no column compatible with the distance function
     */
    public Result run(Database database, Relation<O> relation) {
        DistanceQuery<O> distanceQuery = database.getDistanceQuery(relation, getDistanceFunction(), new Object[0]);

        // Query under test: must be an optimized (non linear-scan) implementation.
        KNNQuery<O> testedQuery = database.getKNNQuery(distanceQuery, Integer.valueOf(this.k), DatabaseQuery.HINT_OPTIMIZED_ONLY);
        if (testedQuery == null || testedQuery instanceof LinearScanQuery) {
            throw new AbortException("Expected an accelerated query, but got a linear scan -- index is not used.");
        }

        // Reference query supplying the results the tested query is compared with.
        final KNNQuery<O> referenceQuery;
        if (this.forcelinear) {
            referenceQuery = QueryUtil.getLinearScanKNNQuery(distanceQuery);
        } else {
            referenceQuery = database.getKNNQuery(distanceQuery, Integer.valueOf(this.k), "exact");
        }
        if (testedQuery.getClass().equals(referenceQuery.getClass())) {
            LOG.warning("Query classes are the same. This experiment may be invalid!");
        }

        final QualityStatistics stats;
        if (this.queries == null || this.pattern != null) {
            stats = validateOnDatabase(database, relation, testedQuery, referenceQuery);
        } else {
            stats = validateOnQuerySet(testedQuery, referenceQuery);
        }
        logStatistics(stats);
        return null;
    }

    /** Runs both queries for sampled (and optionally label-filtered) objects of the relation itself. */
    private QualityStatistics validateOnDatabase(Database database, Relation<O> relation, KNNQuery<O> testedQuery, KNNQuery<O> referenceQuery) {
        // Labels are only looked up when a pattern has to be matched.
        final Relation<String> labels = this.pattern != null ? DatabaseUtil.guessLabelRepresentation(database) : null;
        final DBIDs sample = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        final FiniteProgress progress = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null;
        final QualityStatistics stats = new QualityStatistics();
        for (DBIDIter it = sample.iter(); it.valid(); it.advance()) {
            if (this.pattern == null || this.pattern.matcher(labels.get(it)).find()) {
                recordQuery(stats, testedQuery.getKNNForDBID(it, this.k), referenceQuery.getKNNForDBID(it, this.k));
            }
            // Filtered-out objects still count towards the progress total.
            LOG.incrementProcessed(progress);
        }
        LOG.ensureCompleted(progress);
        return stats;
    }

    /** Runs both queries for sampled objects loaded from the external query source. */
    private QualityStatistics validateOnQuerySet(KNNQuery<O> testedQuery, KNNQuery<O> referenceQuery) {
        final TypeInformation required = getDistanceFunction().getInputTypeRestriction();
        final MultipleObjectsBundle bundle = this.queries.loadData();

        // Use the first column whose type the distance function accepts.
        int column = -1;
        for (int c = 0; c < bundle.metaLength(); c++) {
            if (required.isAssignableFromType(bundle.meta(c))) {
                column = c;
                break;
            }
        }
        if (column < 0) {
            throw new AbortException("No compatible data type in query input was found. Expected: " + required.toString());
        }

        // Temporary ids over the bundle rows, used only to draw the sample.
        final DBIDRange ids = DBIDUtil.generateStaticDBIDRange(bundle.dataLength());
        final DBIDs sample = DBIDUtil.randomSample(ids, this.sampling, this.random);
        final FiniteProgress progress = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null;
        final QualityStatistics stats = new QualityStatistics();
        for (DBIDIter it = sample.iter(); it.valid(); it.advance()) {
            final int row = ids.binarySearch(it);
            if (!$assertionsDisabled && row < 0) {
                throw new AssertionError();
            }
            // The column was selected by the type check above, so its values
            // are accepted by the distance function; the cast is unchecked only
            // because of erasure.
            @SuppressWarnings("unchecked")
            final O queryObject = (O) bundle.data(row, column);
            recordQuery(stats, testedQuery.getKNNForObject(queryObject, this.k), referenceQuery.getKNNForObject(queryObject, this.k));
            LOG.incrementProcessed(progress);
        }
        LOG.ensureCompleted(progress);
        return stats;
    }

    /** Folds the outcome of one query (tested result vs. reference result) into the statistics. */
    private void recordQuery(QualityStatistics stats, KNNList tested, KNNList reference) {
        // Divide as double: integer division would truncate the recall to 0 or 1
        // and throw ArithmeticException for an empty reference result.
        final double referenceSize = reference.size();
        stats.resultSize.put((tested.size() * this.k) / referenceSize);
        stats.recall.put(DBIDUtil.intersectionSize(tested, reference) / referenceSize);
        if (tested.size() < this.k) {
            stats.misses++;
            return;
        }
        final double testedDistance = tested.getKNNDistance();
        final double referenceDistance = reference.getKNNDistance();
        // A zero reference distance would make the relative error undefined.
        if (referenceDistance > 0.0d) {
            stats.kDistance.put(testedDistance);
            stats.absoluteError.put(testedDistance - referenceDistance);
            stats.relativeError.put(testedDistance / referenceDistance);
        }
    }

    /** Writes the collected statistics to the log, if statistics logging is enabled. */
    private void logStatistics(QualityStatistics stats) {
        if (!LOG.isStatistics()) {
            return;
        }
        LOG.statistics("Mean number of results: " + stats.resultSize.getMean() + " +- " + stats.resultSize.getNaiveStddev());
        LOG.statistics("Recall of true results: " + stats.recall.getMean() + " +- " + stats.recall.getNaiveStddev());
        if (stats.kDistance.getCount() > 0.0d) {
            LOG.statistics("Mean k-distance: " + stats.kDistance.getMean() + " +- " + stats.kDistance.getNaiveStddev());
            LOG.statistics("Mean absolute k-error: " + stats.absoluteError.getMean() + " +- " + stats.absoluteError.getNaiveStddev());
            LOG.statistics("Mean relative k-error: " + stats.relativeError.getMean() + " +- " + stats.relativeError.getNaiveStddev());
        }
        if (stats.misses > 0) {
            LOG.statistics(String.format("Number of queries that returned less than k=%d objects: %d (%.2f%%)", Integer.valueOf(this.k), Integer.valueOf(stats.misses), Double.valueOf((stats.misses * 100.0d) / stats.resultSize.getCount())));
        }
    }

    /** Accumulators for one validation run. */
    private static final class QualityStatistics {
        /** Tested result size, scaled by k over the reference result size. */
        final MeanVariance resultSize = new MeanVariance();
        /** Share of the reference result that the tested query also returned. */
        final MeanVariance recall = new MeanVariance();
        /** k-distance reported by the tested query. */
        final MeanVariance kDistance = new MeanVariance();
        /** Tested k-distance minus reference k-distance. */
        final MeanVariance absoluteError = new MeanVariance();
        /** Tested k-distance divided by reference k-distance. */
        final MeanVariance relativeError = new MeanVariance();
        /** Number of queries for which the tested query returned fewer than k objects. */
        int misses;
    }

    /**
     * Reports the input type this algorithm needs, which is exactly the input
     * type restriction of its distance function.
     *
     * @return a one-element array holding the distance function's input type restriction
     */
    @Override
    public TypeInformation[] getInputTypeRestriction() {
        final TypeInformation distanceInput = getDistanceFunction().getInputTypeRestriction();
        return TypeUtil.array(distanceInput);
    }

    /**
     * Gives access to the logger of this class.
     *
     * <p>Decompiler note: the access modifier was changed from {@code protected}.
     *
     * @return the class-level logger
     */
    @Override
    public Logging getLogger() {
        return ValidateApproximativeKNNIndex.LOG;
    }

    static {
        $assertionsDisabled = !ValidateApproximativeKNNIndex.class.desiredAssertionStatus();
        LOG = Logging.getLogger((Class<?>) ValidateApproximativeKNNIndex.class);
    }
}
